"""
Code is adapted from https://github.com/VainF/DeepLabV3Plus-Pytorch/blob/master/datasets/voc.py
NOTE: This code is unfinished; it may be incomplete or currently unused.
"""

import os
import tarfile

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
from torchvision.datasets.utils import download_url

# Per-release download metadata for the Pascal VOC train/val archives.
#   url      - where the archive is fetched from
#   filename - local file name the archive is saved as under ``root``; it must
#              be unique per year, otherwise one year's download overwrites another's
#   md5      - checksum passed to ``download_url`` for the archive
#   base_dir - directory (relative to ``root``) the archive unpacks into
DATASET_YEAR_DICT = {
    "2012": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
        "filename": "VOCtrainval_11-May-2012.tar",
        "md5": "6cd6e144f989b92b3379bac3b3de84fd",
        "base_dir": "VOCdevkit/VOC2012",
    },
    "2011": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
        "filename": "VOCtrainval_25-May-2011.tar",
        "md5": "6c3384ef61512963050cb5d687e5bf1e",
        "base_dir": "TrainVal/VOCdevkit/VOC2011",
    },
    "2010": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
        "filename": "VOCtrainval_03-May-2010.tar",
        "md5": "da459979d0c395079b5c75ee67908abb",
        "base_dir": "VOCdevkit/VOC2010",
    },
    "2009": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
        "filename": "VOCtrainval_11-May-2009.tar",
        "md5": "59065e4b188729180974ef6572f6a212",
        "base_dir": "VOCdevkit/VOC2009",
    },
    "2008": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
        # Must match the URL's archive name; it previously reused the 2012
        # file name, so a 2008 download overwrote a 2012 archive.
        "filename": "VOCtrainval_14-Jul-2008.tar",
        "md5": "2629fa636546599198acfcfbfcf1904a",
        "base_dir": "VOCdevkit/VOC2008",
    },
    "2007": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
        "filename": "VOCtrainval_06-Nov-2007.tar",
        "md5": "c52e279531787c972589f7e41ab4ae64",
        "base_dir": "VOCdevkit/VOC2007",
    },
}


def voc_cmap(N=256, normalized=False):
    """Build the Pascal VOC colour palette.

    Each label id is consumed three bits at a time: bit 0 feeds red, bit 1
    green and bit 2 blue. Over eight rounds, each channel is filled from its
    most significant bit (bit 7) down to bit 0.

    Args:
        N: Number of palette entries, i.e. label ids ``0 .. N-1``.
        normalized: If true, allocate a ``float32`` palette and divide it by
            255; otherwise return ``uint8`` values in ``[0, 255]``.

    Returns:
        np.ndarray of shape ``(N, 3)``; row ``i`` is the RGB colour of label ``i``.
    """
    palette = np.zeros((N, 3), dtype="float32" if normalized else "uint8")
    for label in range(N):
        channels = [0, 0, 0]
        remaining = label
        for shift in range(7, -1, -1):
            for ch in range(3):
                channels[ch] |= ((remaining >> ch) & 1) << shift
            remaining >>= 3
        palette[label] = np.array(channels)

    if normalized:
        palette = palette / 255
    return palette


class VOCSegmentation(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset. Images and masks are
            read from ``root/<base_dir>`` (see ``DATASET_YEAR_DICT``).
        year (string, optional): The dataset year, ``"2007"`` to ``"2012"``, or
            ``"2012_aug"``. With ``"2012_aug"`` and ``image_set="train"``, the
            2012 images are paired with ``SegmentationClassAug`` masks and the
            split list is read from ``root/train_aug.txt``. For other image sets
            ``"2012_aug"`` behaves like ``"2012"``.
        image_set (string, optional): Select the image_set to use, ``train``,
            ``trainval`` or ``val``. The split list is read from
            ``ImageSets/Segmentation/<image_set>.txt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A joint transform called as
            ``transform(image, target)`` with the RGB PIL image and the PIL
            mask. It must return the ``(image, target)`` pair. Single-input
            torchvision transforms cannot be passed directly.

    Raises:
        KeyError: If ``year`` is not a known release.
        RuntimeError: If the dataset directory (or, for ``"2012_aug"``, the
            ``SegmentationClassAug`` directory) does not exist.
        ValueError: If the split list file for ``image_set`` does not exist.
    """

    # Palette shared by all instances; used by ``decode_target``.
    cmap = voc_cmap()

    def __init__(
        self,
        root,
        year="2010",
        image_set="train",
        download=False,
        transform=None,
    ):
        # "2012_aug" is an alias for the 2012 release plus augmented train masks.
        is_aug = year == "2012_aug"
        if is_aug:
            year = "2012"

        meta = DATASET_YEAR_DICT[year]
        self.root = os.path.expanduser(root)
        self.year = year
        self.url = meta["url"]
        self.filename = meta["filename"]
        self.md5 = meta["md5"]
        self.transform = transform
        self.image_set = image_set

        voc_root = os.path.join(self.root, meta["base_dir"])
        image_dir = os.path.join(voc_root, "JPEGImages")

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        if is_aug and image_set == "train":
            mask_dir = os.path.join(voc_root, "SegmentationClassAug")
            if not os.path.exists(mask_dir):
                # Explicit raise rather than ``assert`` so the check survives ``python -O``.
                raise RuntimeError(
                    "SegmentationClassAug not found, please refer to README.md "
                    "and prepare it manually"
                )
            split_f = os.path.join(self.root, "train_aug.txt")  # './datasets/data/train_aug.txt'
        else:
            mask_dir = os.path.join(voc_root, "SegmentationClass")
            splits_dir = os.path.join(voc_root, "ImageSets/Segmentation")
            split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")

        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"'
            )

        with open(split_f, "r") as f:
            # Skip blank lines so they do not turn into ".jpg" / ".png" paths.
            file_names = [line.strip() for line in f if line.strip()]

        # Images and masks are built from the same id list, so they stay aligned.
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the image segmentation,
            passed through ``self.transform(image, target)`` when a transform is set.
        """
        # Context managers close the files deterministically; ``convert`` and
        # ``copy`` produce new in-memory images that remain valid afterwards.
        with Image.open(self.images[index]) as image:
            img = image.convert("RGB")
        with Image.open(self.masks[index]) as mask:
            target = mask.copy()
        if self.transform is not None:
            img, target = self.transform(img, target)

        return img, target

    def __len__(self):
        return len(self.images)

    @classmethod
    def decode_target(cls, mask):
        """Decode a semantic mask to an RGB image.

        ``mask`` is an integer array of label ids. The result has ``mask``'s
        shape plus a trailing axis of 3 colour values taken from ``cls.cmap``.
        """
        return cls.cmap[mask]


def download_extract(url, root, filename, md5):
    """Download ``url`` to ``root/filename`` and unpack the tar archive into ``root``.

    ``md5`` is forwarded to torchvision's ``download_url`` so the archive can be
    verified. Extraction goes through ``safe_extract_tar``, which refuses
    members that would land outside ``root``.
    """
    download_url(url, root, filename, md5)
    safe_extract_tar(os.path.join(root, filename), root)


def safe_extract_tar(archive_path, root):
    """Extract the tar archive at ``archive_path`` into ``root``, blocking path traversal.

    Uses tarfile's ``"data"`` extraction filter when the running Python provides
    it. Otherwise it falls back to rejecting any member whose resolved path lies
    outside ``root``. The fallback does not inspect symlink targets.

    Raises:
        tarfile.TarError: If a member would be written outside ``root``.
    """
    with tarfile.open(archive_path, "r") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=root, filter="data")
            return

        root_real = os.path.realpath(root)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(root, member.name))
            if os.path.commonpath([root_real, target]) != root_real:
                raise tarfile.TarError(
                    f"Refusing to extract {member.name!r} outside {root!r}"
                )
        tar.extractall(path=root)


def get_loader_sampler(root, transform, args, mode):
    """Build the ``DataLoader`` (and optional ``DistributedSampler``) for one split.

    Dist eval introduces slight variance to results if
    num samples != 0 mod (batch size * num gpus).

    Args:
        root: Base data directory; the dataset root is ``os.path.join(root, mode)``.
        transform: Joint ``(image, target)`` transform passed to ``VOCSegmentation``.
        args: Namespace read for ``distributed``, ``batch_size`` and ``workers``.
        mode: Split name. It is used both as the sub-directory of ``root`` and
            as the VOC ``image_set``. For ``"val"``, shuffling and
            ``drop_last`` are disabled.

    Returns:
        tuple: ``(loader, sampler)``; ``sampler`` is ``None`` when not distributed.
    """
    # During evaluation, don't shuffle and don't discard the last partial batch.
    is_eval = mode == "val"

    # ``image_set=mode`` is required: without it every split loads the default
    # "train" list, so the "val" loader would evaluate on training images.
    # NOTE(review): VOC keeps all splits under one devkit, so the extra ``mode``
    # sub-directory may be unintended; kept so existing data layouts still load.
    dataset = VOCSegmentation(
        os.path.join(root, mode), image_set=mode, transform=transform
    )

    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=not is_eval
        )
    else:
        sampler = None
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=(sampler is None and not is_eval),
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=not is_eval,
    )

    return loader, sampler


class _PairedTransform:
    """Apply identical geometric augmentation to an ``(image, mask)`` pair.

    ``VOCSegmentation`` calls its transform as ``transform(img, target)``, so
    single-input torchvision pipelines cannot be used directly. Random
    parameters are sampled once per call and applied to both inputs. The mask
    is resized with nearest-neighbour interpolation so label ids are never
    blended. Defined at module level so DataLoader workers can pickle it.
    """

    def __init__(self, input_size, train):
        self.input_size = input_size
        self.train = train

    def __call__(self, img, target):
        from torchvision.transforms import functional as TF

        size = [self.input_size, self.input_size]
        nearest = TF.InterpolationMode.NEAREST
        if self.train:
            # Same parameters as RandomResizedCrop(input_size, scale=(0.5, 1.0)).
            top, left, height, width = transforms.RandomResizedCrop.get_params(
                img, [0.5, 1.0], [3.0 / 4.0, 4.0 / 3.0]
            )
            img = TF.resized_crop(img, top, left, height, width, size)
            target = TF.resized_crop(
                target, top, left, height, width, size, interpolation=nearest
            )
            # Same coin flip as RandomHorizontalFlip(p=0.5), shared by both inputs.
            if torch.rand(1).item() < 0.5:
                img = TF.hflip(img)
                target = TF.hflip(target)
        else:
            img = TF.center_crop(TF.resize(img, self.input_size), self.input_size)
            target = TF.center_crop(
                TF.resize(target, self.input_size, interpolation=nearest),
                self.input_size,
            )
        # Image -> float CHW tensor in [0, 1]. Mask -> integer label-id tensor
        # (ToTensor would rescale the ids by 1/255).
        return TF.to_tensor(img), torch.as_tensor(np.array(target), dtype=torch.long)


def load_pascal_voc(args):
    """Build the Pascal VOC train/val loaders.

    ``PASCAL_VOC["input_dim"][-1]`` is used as the square crop size.

    Training applies a random resized crop (scale 0.5-1.0) and a random
    horizontal flip. Validation resizes and then centre-crops. Both are applied
    jointly to image and mask.

    Returns:
        tuple: ``(train_loader, train_sampler, val_loader)``.
    """
    input_size = PASCAL_VOC["input_dim"][-1]
    train_transform = _PairedTransform(input_size, train=True)
    val_transform = _PairedTransform(input_size, train=False)

    train_loader, train_sampler = get_loader_sampler(
        args.data, train_transform, args, "train"
    )
    val_loader, _ = get_loader_sampler(args.data, val_transform, args, "val")

    return train_loader, train_sampler, val_loader


# Dataset descriptor for Pascal VOC.
#   normalize - per-channel mean/std, both left as None (presumably "no
#               normalization"; confirm with the consumer of this dict)
#   loader    - builds (train_loader, train_sampler, val_loader) from args
#   input_dim - presumably (channels, height, width); load_pascal_voc reads the
#               last entry as the square crop size
PASCAL_VOC = dict(
    normalize=dict(mean=None, std=None),
    loader=load_pascal_voc,
    input_dim=(3, 128, 128),
)
